//!
//! # Staking: validator scores, stakes and fees
//!

use crate::{
    common::{sha2_sha256, TmAddress},
    ethvm::tx::token::{Erc20Like, DECIMAL},
};
use ethereum_types::{H160, U256};
use ruc::*;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use vsdb::{basic::mapx_ord_rawkey::MapxOrdRawKey, MapxDkVs, MapxOrdVs, Vs};

/// Signed measure of a validator's overall quality; punishments can drive it negative.
type Score = i64;
/// Voting power of a validator: its staking total weighted by its score.
type Power = i64;
/// Token amount in base units (`10^DECIMAL` units make one whole token).
pub type Amount = u64;

/// Raw bytes of a validator's Tendermint public key.
pub type TmPubKeyBytes = Vec<u8>;

const VALIDATOR_CAP: u32 = 100;
const SCORE_MIN_OFFLINE: Score = 1;
const SCORE_MAX: Score = 100_000;

/// `SCORE_MAX` expressed as a `u32`.
pub const SCORE_MAX_U32: u32 = SCORE_MAX as u32;

// One whole token: charged when a stake creates a new validator record.
const FEE_FIRST_TIME: Amount = u64::pow(10, DECIMAL);
// 10^-8 of a token: charged by every other staking operation.
const FEE_NORMAL: Amount = u64::pow(10, DECIMAL.saturating_sub(8));

/// Persistent staking state: validator records, per-staker stakes, fee
/// settings and the Tendermint-address => validator-pubkey index.
#[derive(Vs, Clone, Debug, Deserialize, Serialize)]
pub struct State {
    // Default score of a validator; can be negative.
    // NOTE(review): new validators are created via `Validator::default()`
    // (score `SCORE_MAX`), which does not read this field.
    score_default: Score,

    // For offline behaviors,
    // the `score` should not be deducted to a negative/zero value,
    // that is, the minimum value is 1.
    score_min_offline: Score,

    // Upper bound of a validator's score; initialised to `SCORE_MAX` (10_0000).
    score_max: Score,

    // The capacity definition of formal validators.
    // NOTE(review): `validator_power_top_100` uses a literal 100 rather than
    // reading this field.
    validator_cap: u32,

    // {Validator pubkey => Validator record (score and staking total)}
    validators: MapxOrdVs<TmPubKeyBytes, Validator>,

    // {(Staker address, Validator pubkey) => Amount staked by that staker}
    stakers: MapxDkVs<H160, TmPubKeyBytes, Amount>,

    // Fee for a stake that creates a new validator record: 1 DZK
    fee_first_time: Amount,

    // Fee for every other staking operation: 0.0000_0001 DZK
    fee_normal: Amount,

    // {Tendermint address (see `pubkey_to_addr`) => Validator pubkey},
    // used to resolve the validators named in punishments.
    addr_to_pubkey: MapxOrdRawKey<TmPubKeyBytes>,
}

impl State {
    /// Creates an empty staking state initialised from the module constants
    /// (`SCORE_MAX`, `SCORE_MIN_OFFLINE`, `VALIDATOR_CAP`, `FEE_FIRST_TIME`,
    /// `FEE_NORMAL`); every map starts empty.
    #[inline(always)]
    pub fn new() -> Self {
        Self {
            score_default: SCORE_MAX,
            score_min_offline: SCORE_MIN_OFFLINE,
            score_max: SCORE_MAX,
            validator_cap: VALIDATOR_CAP,
            validators: MapxOrdVs::new(),
            stakers: MapxDkVs::new(),
            fee_first_time: FEE_FIRST_TIME,
            fee_normal: FEE_NORMAL,
            addr_to_pubkey: MapxOrdRawKey::new(),
        }
    }

    /// Registers the genesis validator set.
    ///
    /// Every `(pubkey, amount)` pair becomes a validator with `amount` as its
    /// staking total and the maximum score. Its Tendermint address is indexed
    /// so punishments can find it. No per-staker records are created, so
    /// these amounts cannot be withdrawn through `unstake_from`/`unstake_all`.
    ///
    /// # Errors
    /// Fails if a validator record cannot be written.
    #[inline(always)]
    pub fn set_initial_validators(
        &mut self,
        validators: BTreeMap<TmPubKeyBytes, Amount>,
    ) -> Result<()> {
        for (pk, amount) in validators {
            let v = Validator::new(amount);
            self.validators.insert_ref(&pk, &v).c(d!())?;
            self.addr_to_pubkey.insert_ref(&pubkey_to_addr(&pk), &pk);
        }
        Ok(())
    }

    /// Upper bound of a validator's score.
    #[inline(always)]
    pub fn score_max(&self) -> Score {
        self.score_max
    }

    /// The regular fee (`fee_normal`) charged by staking operations.
    #[inline(always)]
    pub fn static_fee(&self) -> Amount {
        self.fee_normal
    }

    /// Stakes `amount` from `staker`'s native balance to `validator` and
    /// returns the fee that was charged.
    ///
    /// - A zero `amount` is a no-op: nothing is charged, and the nonce is
    ///   neither checked nor consumed.
    /// - `nonce` must equal the staker's current nonce, which is then bumped.
    /// - The fee is `fee_normal` when `validator` already has a record, and
    ///   `fee_first_time` when this call creates the record.
    ///
    /// # Errors
    /// Bad nonce, a balance below `amount + fee`, arithmetic overflow, or a
    /// failed storage write.
    pub fn stake_to(
        &self,
        token: &Erc20Like,
        staker: H160,
        validator: &TmPubKeyBytes,
        amount: Amount,
        nonce: U256,
    ) -> Result<U256> {
        alt!(0 == amount, return Ok(U256::zero()));

        let mut acc = token.native_account(staker).unwrap_or_default();
        alt!(nonce != acc.nonce, return Err(eg!("bad nonce")));
        acc.nonce += U256::one();

        // Creating a brand-new validator record costs more than adding to an
        // existing one.
        let fee = if self.validators.contains_key(validator) {
            self.fee_normal
        } else {
            self.fee_first_time
        };

        let old_am = self.stakers.get(&(&staker, validator)).unwrap_or(0);

        let balance_needed = amount.checked_add(fee).c(d!())?;
        acc.balance = acc
            .balance
            .checked_sub(balance_needed.into())
            .c(d!("insufficient balance"))?;

        let new_am = old_am.checked_add(amount).c(d!())?;

        let mut v = self.validators.get(validator).unwrap_or_default();
        v.staking_total = v.staking_total.checked_add(amount).c(d!())?;

        // All checks passed; persist the account, the stake and the validator.
        token
            .accounts
            .insert(staker, acc)
            .c(d!())
            .and_then(|_| {
                self.stakers
                    .insert_ref(&(&staker, validator), &new_am)
                    .c(d!())
            })
            .and_then(|_| self.validators.insert_ref(validator, &v).c(d!()))?;

        self.addr_to_pubkey
            .insert_ref(&pubkey_to_addr(validator), validator);

        Ok(fee.into())
    }

    /// Withdraws `amount` of `staker`'s stake from `validator` back to the
    /// staker's native balance. Returns the fee that was charged
    /// (`fee_normal`).
    ///
    /// - A zero `amount` is a no-op: nothing is charged, and the nonce is
    ///   neither checked nor consumed.
    /// - Withdrawing from a validator whose score is negative is refused.
    /// - The staker's record is deleted when its stake reaches zero.
    /// - The validator's record is deleted when its staking total reaches zero.
    ///
    /// # Errors
    /// Bad nonce, a balance that cannot cover the fee, unknown validator,
    /// negative validator score, a stake smaller than `amount`, or a failed
    /// storage write.
    pub fn unstake_from(
        &self,
        token: &Erc20Like,
        staker: H160,
        validator: &TmPubKeyBytes,
        amount: Amount,
        nonce: U256,
    ) -> Result<U256> {
        alt!(0 == amount, return Ok(U256::zero()));

        let mut acc = token.native_account(staker).unwrap_or_default();
        alt!(nonce != acc.nonce, return Err(eg!("bad nonce")));
        acc.nonce += U256::one();

        let fee = U256::from(self.fee_normal);

        // The fee must be payable from the existing balance, before the
        // withdrawn amount is credited.
        acc.balance = acc
            .balance
            .checked_sub(fee)
            .c(d!("insufficient balance"))?
            .checked_add(amount.into())
            .c(d!())?;

        let mut v = self
            .validators
            .get(validator)
            .c(d!("validator does not exist"))?;
        if 0 > v.score {
            return Err(eg!(
                "unstake from validators with negative scores are not allowed"
            ));
        }
        v.staking_total = v.staking_total.checked_sub(amount).c(d!())?;

        let old_am = self.stakers.get(&(&staker, validator)).c(d!())?;
        let new_am = old_am
            .checked_sub(amount)
            .c(d!("insufficient staking amount"))?;

        // Every check above happens before the first write below.
        if 0 == new_am {
            self.stakers.remove(&(&staker, Some(validator))).c(d!())?;
        } else {
            self.stakers
                .insert_ref(&(&staker, validator), &new_am)
                .c(d!())?;
        }

        token.accounts.insert(staker, acc).c(d!())?;

        if 0 == v.staking_total {
            self.validators.remove(validator).c(d!())?;
        } else {
            self.validators.insert_ref(validator, &v).c(d!())?;
        }

        Ok(fee)
    }

    /// Withdraws every stake `staker` holds with validators whose score is
    /// non-negative. Returns the fee that was charged (`fee_normal`).
    ///
    /// Stakes with negatively scored validators are left untouched. The fee
    /// is charged and the nonce consumed even when nothing is withdrawn.
    ///
    /// # Errors
    /// Bad nonce, a balance that cannot cover the fee, a staked validator
    /// without a record, arithmetic overflow, or a failed storage write.
    pub fn unstake_all(
        &self,
        token: &Erc20Like,
        staker: H160,
        nonce: U256,
    ) -> Result<U256> {
        let mut acc = token.native_account(staker).unwrap_or_default();
        alt!(nonce != acc.nonce, return Err(eg!("bad nonce")));
        acc.nonce += U256::one();

        let fee = U256::from(self.fee_normal);

        acc.balance = acc.balance.checked_sub(fee).c(d!("insufficient balance"))?;

        // First pass: credit the balance and compute each validator's new
        // total. Second pass: write the updated records back.
        let mut validators = vec![];
        let mut cb = |(_, vpk): (H160, TmPubKeyBytes), am: Amount| -> Result<()> {
            let mut v = self.validators.get(&vpk).c(d!())?;
            if 0 <= v.score {
                acc.balance = acc.balance.checked_add(am.into()).c(d!())?;
                v.staking_total = v.staking_total.checked_sub(am).c(d!())?;
                validators.push((vpk, v));
            }
            Ok(())
        };
        self.stakers
            .iter_op_with_key_prefix(&mut cb, &staker)
            .c(d!())?;

        for (vpk, v) in validators.iter() {
            self.stakers.remove(&(&staker, Some(vpk))).c(d!())?;
            if 0 == v.staking_total {
                self.validators.remove(vpk).c(d!())?;
            } else {
                self.validators.insert_ref(vpk, v).c(d!())?;
            }
        }

        token.accounts.insert(staker, acc).c(d!()).map(|_| fee)
    }

    /// Current score of `validator`, or `None` if it has no record.
    #[inline(always)]
    pub fn validator_get_score(&self, validator: &TmPubKeyBytes) -> Option<Score> {
        self.validators.get(validator).map(|v| v.score)
    }

    /// Total amount staked to `validator`, or `None` if it has no record.
    #[inline(always)]
    pub fn validator_get_staking_total(
        &self,
        validator: &TmPubKeyBytes,
    ) -> Option<Amount> {
        self.validators.get(validator).map(|v| v.staking_total)
    }

    /// Voting power of `validator`, or `None` if it has no record.
    #[inline(always)]
    pub fn validator_get_power(&self, validator: &TmPubKeyBytes) -> Option<Power> {
        self.validators.get(validator).map(|v| v.voting_power())
    }

    /// Whether `validator` is in the formal validator list
    /// (see `validator_power_top_100`).
    #[inline(always)]
    pub fn validator_in_formal_list(&self, validator: &TmPubKeyBytes) -> bool {
        self.validator_power_top_100()
            .iter()
            .any(|(pk, _)| pk == validator)
    }

    /// Formal validator list: the (at most) 100 validators with the highest
    /// positive voting power. 100 matches `VALIDATOR_CAP`.
    #[inline(always)]
    pub fn validator_power_top_100(&self) -> Vec<(TmPubKeyBytes, Power)> {
        self.validator_power_top_n(100)
    }

    /// Returns up to `n` validators with the highest positive voting power,
    /// ordered by ascending power, so the strongest validator is last.
    ///
    /// Equal powers are ordered by public key (ascending). Every pubkey is a
    /// distinct map key, so this is a strict total order. As a result, which
    /// validators survive the cut at a tie is fully determined by the state
    /// and does not depend on the standard library's unstable-sort internals.
    // TODO: implement a pre-sorted cache to optimize performance
    pub fn validator_power_top_n(&self, n: usize) -> Vec<(TmPubKeyBytes, Power)> {
        let mut validators = self
            .validators
            .iter()
            .map(|(pk, v)| (pk, v.voting_power()))
            .filter(|(_, power)| 0 < *power)
            .collect::<Vec<_>>();
        validators.sort_unstable_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
        validators.split_off(validators.len().saturating_sub(n))
    }

    /// Adds one point to every validator's score (capped at `score_max`).
    /// Applied unconditionally once per block.
    #[inline(always)]
    pub fn validator_score_incr_by_new_block(&self) -> Result<()> {
        self.validator_score_incr_by_n(None, 1).c(d!())
    }

    /// Rewards an online `validator` with 100 points (capped at `score_max`).
    ///
    /// # Errors
    /// Fails if `validator` has no record.
    #[inline(always)]
    pub fn validator_score_incr_by_online(
        &self,
        validator: &TmPubKeyBytes,
    ) -> Result<()> {
        self.validator_score_incr_by_n(Some(validator), 100).c(d!())
    }

    /// Adds `n` to the score of `validator`, or to every validator when
    /// `validator` is `None`.
    ///
    /// Scores already at or above `score_max` are left unchanged; other
    /// scores are capped at `score_max`.
    ///
    /// # Errors
    /// Fails if a named `validator` has no record, or a write fails.
    #[inline(always)]
    pub fn validator_score_incr_by_n(
        &self,
        validator: Option<&TmPubKeyBytes>,
        n: Score,
    ) -> Result<()> {
        if let Some(v) = validator {
            self.validators
                .get_mut(v)
                .map(|mut v| {
                    if v.score < self.score_max {
                        v.score_incr_by_n(n, self.score_max);
                    }
                })
                .c(d!())?;
        } else {
            // Compute all updates first, then write them back, so the map is
            // never modified while it is being iterated.
            let updated = self
                .validators
                .iter()
                .filter(|(_, v)| v.score < self.score_max)
                .map(|(pk, mut v)| {
                    v.score_incr_by_n(n, self.score_max);
                    (pk, v)
                })
                .collect::<Vec<_>>();
            for (pk, v) in updated {
                self.validators.insert(pk, v).c(d!())?;
            }
        }
        Ok(())
    }

    /// Applies the offline penalty to `validator`: 1000 points, floored at
    /// `score_min_offline`. Scores already at or below that floor are left
    /// unchanged.
    ///
    /// # Errors
    /// Fails if `validator` has no record.
    #[inline(always)]
    pub fn validator_score_decr_by_offline(
        &self,
        validator: &TmPubKeyBytes,
    ) -> Result<()> {
        self.validators
            .get_mut(validator)
            .map(|mut v| {
                if v.score > self.score_min_offline {
                    v.score_decr_by_offline(self.score_min_offline);
                }
            })
            .c(d!())
    }

    /// Punishes malicious behavior: deducts `score_max * 100` points with no
    /// lower bound.
    ///
    /// # Errors
    /// Fails if `validator` has no record.
    #[inline(always)]
    pub fn validator_score_decr_by_punishment(
        &self,
        validator: &TmPubKeyBytes,
    ) -> Result<()> {
        self.validators
            .get_mut(validator)
            .map(|mut v| v.score_decr_by_punishment(self.score_max))
            .c(d!())
    }

    /// Deducts `n` points from `validator`. When `reason_is_offline` is set,
    /// the result is floored at `score_min_offline`; otherwise it is
    /// unbounded below.
    ///
    /// # Errors
    /// Fails if `validator` has no record.
    #[inline(always)]
    pub fn validator_score_decr_by_n(
        &self,
        validator: &TmPubKeyBytes,
        n: Score,
        reason_is_offline: bool,
    ) -> Result<()> {
        let min_in_offline = if reason_is_offline {
            Some(self.score_min_offline)
        } else {
            None
        };
        self.validators
            .get_mut(validator)
            .map(|mut v| v.score_decr_by_n(n, min_in_offline))
            .c(d!())
    }

    /// Applies `punishments` in order, resolving validators by Tendermint
    /// address.
    ///
    /// Processing stops at the first error, such as an unknown address or a
    /// validator without a record. Updates made before that point are not
    /// rolled back.
    pub fn apply_punishments(&self, punishments: Vec<Punishment>) -> Result<()> {
        for p in punishments.into_iter() {
            match p {
                Punishment::Malicious(validators) => {
                    for m in validators.iter() {
                        self.addr_to_pubkey.get(m).c(d!()).and_then(|mpk| {
                            self.validator_score_decr_by_punishment(&mpk).c(d!())
                        })?;
                    }
                }
                Punishment::Offline((offline_validators, online_validators)) => {
                    for v in offline_validators.iter() {
                        self.addr_to_pubkey.get(v).c(d!()).and_then(|vpk| {
                            self.validator_score_decr_by_offline(&vpk).c(d!())
                        })?;
                    }
                    for v in online_validators.iter() {
                        self.addr_to_pubkey.get(v).c(d!()).and_then(|vpk| {
                            self.validator_score_incr_by_online(&vpk).c(d!())
                        })?;
                    }
                }
            }
        }

        Ok(())
    }

    /// Per-block governance hook: adds one point to every validator's score,
    /// then applies `governances`.
    #[inline(always)]
    pub fn governance_with_each_block(
        &self,
        governances: Vec<Punishment>,
    ) -> Result<()> {
        self.validator_score_incr_by_new_block()
            .c(d!())
            .and_then(|_| self.apply_punishments(governances).c(d!()))
    }
}

impl Default for State {
    #[inline(always)]
    fn default() -> Self {
        Self::new()
    }
}

/// Per-validator record stored in `State::validators`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct Validator {
    // Quality score. It may become negative after a punishment; the validator
    // then has zero voting power and its stakes cannot be withdrawn.
    score: Score,

    // All tokens staked to it, including its own share.
    staking_total: Amount,
}

impl Validator {
    #[inline(always)]
    fn new(staking_total: Amount) -> Self {
        Self {
            score: SCORE_MAX,
            staking_total,
        }
    }

    #[inline(always)]
    fn voting_power(&self) -> Power {
        let score = alt!(self.score < 0, 0, self.score);
        Power::try_from(self.staking_total.saturating_mul(score as Amount))
            .unwrap_or(Power::MAX)
    }

    // // unconditional increment
    // #[inline(always)]
    // fn score_incr_by_one(&mut self, max: Score) {
    //     self.score_incr_by_n(1, max)
    // }

    // #[inline(always)]
    // fn score_incr_by_online(&mut self, max: Score) {
    //     self.score_incr_by_n(100, max)
    // }

    #[inline(always)]
    fn score_incr_by_n(&mut self, n: Score, max: Score) {
        let mut new_score = self.score.saturating_add(n);
        if new_score > max {
            new_score = max;
        }
        self.score = new_score;
    }

    #[inline(always)]
    fn score_decr_by_offline(&mut self, min_in_offline: Score) {
        self.score_decr_by_n(1000, Some(min_in_offline))
    }

    // malicious behavior
    #[inline(always)]
    fn score_decr_by_punishment(&mut self, max: Score) {
        self.score_decr_by_n(max.saturating_mul(100), None)
    }

    #[inline(always)]
    fn score_decr_by_n(&mut self, n: Score, min_in_offline: Option<Score>) {
        let mut new_score = self.score.saturating_sub(n);
        if let Some(min) = min_in_offline {
            if min > new_score {
                new_score = min;
            }
        }
        self.score = new_score;
    }
}

impl Default for Validator {
    fn default() -> Self {
        Self::new(0)
    }
}

/// Governance events applied by `State::apply_punishments`; validators are
/// identified by their Tendermint address.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum Punishment {
    // malicious validators: each score is cut by `score_max * 100`
    Malicious(Vec<TmAddress>),
    // (offline validators, online validators): offline ones lose 1000 points
    // (floored at `score_min_offline`); online ones gain 100 points
    // (capped at `score_max`)
    Offline((Vec<TmAddress>, Vec<TmAddress>)),
}

/// Derives a validator's Tendermint address: the first 20 bytes of
/// `sha2_sha256` computed over the public key bytes.
#[inline(always)]
fn pubkey_to_addr(pk: &[u8]) -> TmAddress {
    const ADDR_LEN: usize = 20;
    let digest = sha2_sha256(&[pk]);
    digest[..ADDR_LEN].into()
}

// Unit tests for validator scoring and the staking/unstaking bookkeeping.
#[cfg(test)]
mod test {
    use super::*;
    use crate::ethvm::DzkAccount;
    use ethereum_types::U256;
    use vsdb::{VersionName, VsMgmt};

    // Score arithmetic on a bare `Validator` and its effect on voting power.
    #[test]
    fn score_mgmt() {
        // Default validator: maximum score but no stake, hence no power.
        let mut v = Validator::default();
        assert_eq!(v.voting_power(), 0);

        // The cap also pulls a score that is already above `max` down to it.
        v.score_incr_by_n(100, 50);
        assert_eq!(v.score, 50);
        assert_eq!(v.voting_power(), 0);

        v.staking_total = 1;
        assert_eq!(v.voting_power(), 50);

        // The offline penalty (1000) is floored at the given minimum.
        v.score_decr_by_offline(1);
        assert_eq!(v.score, 1);
        assert_eq!(v.voting_power(), 1);

        // Punishment has no floor; a negative score yields zero power.
        v.score_decr_by_punishment(50);
        assert_eq!(v.score, 1 - 100 * 50);
        assert_eq!(v.voting_power(), 0);
    }

    // End-to-end flow: fees, nonces, unstaking rules, per-block score
    // recovery and formal-list membership.
    #[test]
    fn staking_mgmt() {
        // Open an initial version on both versioned stores before writing.
        let s = State::new();
        pnk!(s.version_create(VersionName(b"")));

        let v1 = vec![1; 20];
        let v2 = vec![2; 20];
        let v3 = vec![3; 20];

        let token = Erc20Like::default();
        pnk!(token.version_create(VersionName(b"")));

        let addr1 = H160::random();
        let mut balance1 = U256::from(10 * FEE_FIRST_TIME);
        let acc1 = DzkAccount::from_balance(balance1);
        let mut nonce1 = U256::zero();

        let addr2 = H160::random();
        let mut balance2 = balance1 * U256::from(2u32);
        let acc2 = DzkAccount::from_balance(balance2);
        let mut nonce2 = U256::zero();

        pnk!(token.accounts.insert(addr1, acc1));
        pnk!(token.accounts.insert(addr2, acc2));

        assert_eq!(100, s.validator_cap);

        // The first stake to each validator creates it and costs FEE_FIRST_TIME.
        pnk!(s.stake_to(&token, addr1, &v1, 1, nonce1));
        assert!(s.stake_to(&token, addr1, &v2, 2, nonce1).is_err()); // bad nonce
        nonce1 += U256::one();
        pnk!(s.stake_to(&token, addr1, &v2, 2, nonce1));
        nonce1 += U256::one();
        pnk!(s.stake_to(&token, addr1, &v3, 3, nonce1));
        nonce1 += U256::one();

        balance1 -= (3 * FEE_FIRST_TIME + 1 + 2 + 3).into();
        assert_eq!(token.native_balance(addr1), balance1);

        assert_eq!(s.validator_get_score(&v1).unwrap(), s.score_default);
        assert_eq!(s.validator_get_score(&v2).unwrap(), s.score_default);
        assert_eq!(s.validator_get_score(&v3).unwrap(), s.score_default);

        assert_eq!(s.validator_get_staking_total(&v1).unwrap(), 1);
        assert_eq!(s.validator_get_staking_total(&v2).unwrap(), 2);
        assert_eq!(s.validator_get_staking_total(&v3).unwrap(), 3);

        assert_eq!(
            s.validator_get_power(&v1).unwrap(),
            1 * (s.score_default as Power)
        );
        assert_eq!(
            s.validator_get_power(&v2).unwrap(),
            2 * (s.score_default as Power)
        );
        assert_eq!(
            s.validator_get_power(&v3).unwrap(),
            3 * (s.score_default as Power)
        );

        assert_eq!(s.validator_in_formal_list(&v1), true);
        assert_eq!(s.validator_in_formal_list(&v2), true);
        assert_eq!(s.validator_in_formal_list(&v3), true);

        // A partial unstake charges FEE_NORMAL and returns the amount.
        pnk!(s.unstake_from(&token, addr1, &v3, 1, nonce1));
        nonce1 += U256::one();

        balance1 -= FEE_NORMAL.into();
        balance1 += 1.into();
        assert_eq!(token.native_balance(addr1), balance1);

        assert_eq!(s.validator_in_formal_list(&v1), true);
        assert_eq!(s.validator_in_formal_list(&v2), true);
        assert_eq!(s.validator_in_formal_list(&v3), true);

        // Unstaking the rest of v3 deletes its record, so it leaves the list.
        pnk!(s.unstake_from(&token, addr1, &v3, 2, nonce1));
        nonce1 += U256::one();

        balance1 -= FEE_NORMAL.into();
        balance1 += 2.into();
        assert_eq!(token.native_balance(addr1), balance1);

        assert_eq!(s.validator_in_formal_list(&v1), true);
        assert_eq!(s.validator_in_formal_list(&v2), true);
        assert_eq!(s.validator_in_formal_list(&v3), false);

        // v1 goes offline (score stays positive); v2 is punished (score negative).
        pnk!(s.validator_score_decr_by_offline(&v1));
        pnk!(s.validator_score_decr_by_punishment(&v2));

        // v2 is untouched because its score is under zero
        pnk!(s.unstake_all(&token, addr1, nonce1));
        nonce1 += U256::one();

        balance1 -= FEE_NORMAL.into();
        balance1 += 1.into();
        assert_eq!(token.native_balance(addr1), balance1);

        assert_eq!(s.validator_in_formal_list(&v1), false);
        assert_eq!(s.validator_in_formal_list(&v2), false);
        assert_eq!(s.validator_in_formal_list(&v3), false);

        assert!(s.validator_get_power(&v1).is_none());
        assert_eq!(s.validator_get_power(&v2).unwrap(), 0);
        assert!(s.validator_get_power(&v3).is_none());

        // v1 and v3 were deleted, so re-staking costs FEE_FIRST_TIME again;
        // v2 still has a record and costs FEE_NORMAL.
        pnk!(s.stake_to(&token, addr1, &v1, 10, nonce1));
        nonce1 += U256::one();
        pnk!(s.stake_to(&token, addr1, &v2, 20, nonce1));
        nonce1 += U256::one();
        pnk!(s.stake_to(&token, addr1, &v3, 30, nonce1));
        nonce1 += U256::one();

        balance1 -= (2 * FEE_FIRST_TIME + FEE_NORMAL + 10 + 20 + 30).into();
        assert_eq!(token.native_balance(addr1), balance1);

        pnk!(s.validator_score_decr_by_offline(&v1));
        pnk!(s.validator_score_decr_by_punishment(&v2));

        assert_eq!(s.validator_get_staking_total(&v1).unwrap(), 10);
        assert_eq!(s.validator_get_staking_total(&v2).unwrap(), 22);
        assert_eq!(s.validator_get_staking_total(&v3).unwrap(), 30);

        assert_eq!(
            s.validator_get_power(&v1).unwrap(),
            10 * ((s.score_default - 1000) as Power)
        );
        assert_eq!(s.validator_get_power(&v2).unwrap(), 0);
        assert_eq!(
            s.validator_get_power(&v3).unwrap(),
            30 * (s.score_default as Power)
        );

        // Zero-amount stakes are no-ops: no fee, and the nonce is not consumed.
        pnk!(s.stake_to(&token, addr2, &v1, 0, nonce2));
        pnk!(s.stake_to(&token, addr2, &v2, 0, nonce2));
        pnk!(s.stake_to(&token, addr2, &v3, 0, nonce2));

        // no fee charged
        assert_eq!(token.native_balance(addr2), balance2);

        assert_eq!(s.validator_get_staking_total(&v1).unwrap(), 10);
        assert_eq!(s.validator_get_staking_total(&v2).unwrap(), 22);
        assert_eq!(s.validator_get_staking_total(&v3).unwrap(), 30);

        assert_eq!(
            s.validator_get_power(&v1).unwrap(),
            10 * ((s.score_default - 1000) as Power)
        );
        assert_eq!(s.validator_get_power(&v2).unwrap(), 0);
        assert_eq!(
            s.validator_get_power(&v3).unwrap(),
            30 * (s.score_default as Power)
        );

        // Each new block adds one point: 100 blocks recover 100 of v1's
        // 1000-point offline penalty.
        (0..100).for_each(|_| {
            pnk!(s.validator_score_incr_by_new_block());
        });

        assert_eq!(
            s.validator_get_power(&v1).unwrap(),
            10 * ((s.score_default - 900) as Power)
        );
        assert_eq!(s.validator_get_power(&v2).unwrap(), 0);
        assert_eq!(
            s.validator_get_power(&v3).unwrap(),
            30 * (s.score_default as Power)
        );

        // Recovery is capped at the maximum score.
        (0..2000).for_each(|_| {
            pnk!(s.validator_score_incr_by_new_block());
        });

        assert_eq!(
            s.validator_get_power(&v1).unwrap(),
            10 * (s.score_default as Power)
        );
        assert_eq!(s.validator_get_power(&v2).unwrap(), 0);
        assert_eq!(
            s.validator_get_power(&v3).unwrap(),
            30 * (s.score_default as Power)
        );

        // Every validator already has a record, so addr2 pays FEE_NORMAL each time.
        pnk!(s.stake_to(&token, addr2, &v1, 10, nonce2));
        nonce2 += U256::one();
        pnk!(s.stake_to(&token, addr2, &v2, 100, nonce2));
        nonce2 += U256::one();
        pnk!(s.stake_to(&token, addr2, &v3, 1000, nonce2));
        nonce2 += U256::one();

        balance2 -= (3 * FEE_NORMAL + 10 + 100 + 1000).into();
        assert_eq!(token.native_balance(addr2), balance2);

        // A zero-amount unstake is a no-op.
        pnk!(s.unstake_from(&token, addr2, &v2, 0, nonce2));
        // no fee charged
        assert_eq!(token.native_balance(addr2), balance2);

        // v2's negative score blocks any withdrawal from it.
        assert!(s.unstake_from(&token, addr2, &v2, 1, nonce2).is_err());
        // no fee charged
        assert_eq!(token.native_balance(addr2), balance2);

        assert!(s.unstake_from(&token, addr2, &v2, 1000, nonce2).is_err());
        // no fee charged
        assert_eq!(token.native_balance(addr2), balance2);
    }
}
